import argparse
from pathlib import Path
from tqdm import tqdm

# torch

import torch

from einops import repeat

# vision imports

from PIL import Image
from torchvision.utils import make_grid, save_image

# dalle related classes and utils

# from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
from dalle_pytorch.dalle_pytorch_ori import DALLE_PG_Discrete, DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE_PG, DALLE_PG_Discrete, DiscretePGVAE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer
from IPython import embed
import os
from pytorch3d.io import save_ply, load_ply

import sys
sys.path.insert(0, '/home/tiangel/DALLE_3D/Learning-to-Group')
from shaper.models.pointnet2.modules import PointNetSAModule, PointnetFPModule
from partnet.utils.torch_pc import normalize_points as normalize_points_torch

from geometry_utils import render_pts, rotate_pts, render_pts_with_label
import torch.nn.functional as F
# argument parsing

parser = argparse.ArgumentParser()

parser.add_argument('--dalle_path', type = str, required = True,
                    help='checkpoint of your trained DALL-E, resolved relative to ./outputs/dalle_models')

parser.add_argument('--vqgan_model_path', type=str, default = None,
                    help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')

parser.add_argument('--vqgan_config_path', type=str, default = None,
                    help='path to your trained VQGAN config. This should be a .yaml file.  (only valid when taming option is enabled)')

parser.add_argument('--pts_path', type = str, required=True,
                    help='input point cloud (.ply), relative to the validation point-cloud directory (absolute paths are used as-is)')

parser.add_argument('--text', type = str, default = None,
                    help='your text prompt; separate several prompts with "|" (text-conditioned mode only)')

parser.add_argument('--num_images', type = int, default = 128, required = False,
                    help='number of programs to sample from the input point cloud')

parser.add_argument('--batch_size', type = int, default = 4, required = False,
                    help='batch size')

parser.add_argument('--top_k', type = float, default = 0.9, required = False,
                    help='top k filter threshold')

parser.add_argument('--outputs_dir', type = str, default = './outputs/dalle_outputs', required = False,
                    help='output directory (text-conditioned mode only)')

parser.add_argument('--save_name', type = str, default = '1',
                    help='results are written to ./shape2prog/vqprogram_outputs/test<save_name>')

parser.add_argument('--bpe_path', type = str,
                    help='path to your huggingface BPE json file')

parser.add_argument('--hug', dest='hug', action = 'store_true')

parser.add_argument('--chinese', dest='chinese', action = 'store_true')

parser.add_argument('--taming', dest='taming', action='store_true')

parser.add_argument('--gentxt', dest='gentxt', action='store_true')

args = parser.parse_args()

# helper fns

def exists(val):
    """Return True when *val* is anything other than ``None``.

    Falsy values such as ``0``, ``''`` or ``[]`` still count as existing.
    """
    is_missing = val is None
    return not is_missing

# tokenizer

# Override the default `tokenizer` imported from dalle_pytorch.tokenizer only
# when a BPE file or the Chinese tokenizer was requested on the command line.
if args.bpe_path is None:
    if args.chinese:
        tokenizer = ChineseTokenizer()
else:
    tokenizer_cls = HugTokenizer if args.hug else YttmTokenizer
    tokenizer = tokenizer_cls(args.bpe_path)

# load DALL-E

dalle_path = Path('./outputs/dalle_models') / args.dalle_path

# Explicit raise instead of `assert` so the check survives `python -O`.
if not dalle_path.exists():
    raise FileNotFoundError(f'trained DALL-E must exist (looked for {dalle_path})')

# NOTE: torch.load unpickles the file; only load checkpoints you trust.
load_obj = torch.load(str(dalle_path))

# Required entries (KeyError if absent), popped in the original order.
dalle_params = load_obj.pop('hparams')
vae_params = load_obj.pop('vae_params')
pgvae_params = load_obj.pop('pgvae_params')
weights = load_obj.pop('weights')
# Optional entries written only by newer training runs.
vae_class_name = load_obj.pop('vae_class_name', None)
version = load_obj.pop('version', None)

# friendly print

if exists(version):
    print(f'Loading a model trained with DALLE-pytorch version {version}')
else:
    print('You are loading a model trained on an older version of DALL-E pytorch - it may not be compatible with the most recent version')

# load VAE

if args.taming:
    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
elif vae_params is not None:
    vae = DiscreteVAE(**vae_params)
else:
    vae = OpenAIDiscreteVAE()

# Fail fast if the VAE chosen above is not the class recorded in the checkpoint.
# An explicit raise (rather than `assert`) keeps this check under `python -O`.
if exists(vae_class_name) and vae.__class__.__name__ != vae_class_name:
    raise ValueError(f'you trained DALL-E using {vae_class_name} but are trying to generate with {vae.__class__.__name__} - please make sure you are passing in the correct paths and settings for the VAE to use for generation')

# pgvae is rebuilt from the kwargs stored in the checkpoint.
pgvae = DiscretePGVAE(**pgvae_params)

# reconstitute DALL-E

dalle = DALLE_PG(vae = vae, pgvae = pgvae, **dalle_params).cuda()
# Alternative model class, kept for experiments:
# dalle = DALLE_PG_Discrete(vae = vae, pgvae = pgvae, **dalle_params).cuda()

dalle.load_state_dict(weights)
# Inference only: switch modules such as dropout to evaluation behaviour.
dalle.eval()

# generate shape programs from the input point cloud

# Root directory that relative --pts_path values are resolved against
# (os.path.join keeps an absolute --pts_path unchanged).
PC_ROOT = '/home/tiangel/datasets/val_pc_10000p_6k/'

# Layout of each generated sample, as fixed by the slicing/reshapes below:
# the last axis holds NUM_PGM_TYPES program-token scores followed by
# NUM_PGM_PARAMS continuous parameters, and the rows regroup into
# (NUM_PGM_BLOCKS, NUM_PGM_STEPS). NOTE(review): the block/step naming is
# presumably shape2prog's program layout -- confirm against DALLE_PG.
NUM_PGM_BLOCKS = 10
NUM_PGM_STEPS = 3
NUM_PGM_TYPES = 22
NUM_PGM_PARAMS = 7

save_dir = os.path.join('./shape2prog/vqprogram_outputs', 'test' + args.save_name)
pts_save_dir = os.path.join(save_dir, 'pts')
# makedirs also creates any missing parent directories (including save_dir).
os.makedirs(pts_save_dir, exist_ok=True)

pc_path = os.path.join(PC_ROOT, args.pts_path)
pc = load_ply(pc_path)  # (verts, faces); only the vertices are used
points = normalize_points_torch(pc[0].unsqueeze(0)).cuda()
# Replicate the single cloud so args.num_images samples are drawn from it.
points = repeat(points, '() n c -> b n c', b = args.num_images)

outputs = []
# Pure inference: no autograd graph is needed for sampling.
with torch.no_grad():
    for pts_chunk in tqdm(points.split(args.batch_size), desc = f'generating images for - {pc_path}'):
        outputs.append(dalle.generate_pgs(pts_chunk, filter_thres = args.top_k))
out = torch.cat(outputs)

pgm_scores = out[:, :, :NUM_PGM_TYPES].reshape(-1, NUM_PGM_BLOCKS, NUM_PGM_STEPS, NUM_PGM_TYPES)
# log_softmax is monotonic, so argmax over the raw scores picks the same token.
out_pgm = pgm_scores.argmax(dim=-1)
out_param = out[:, :, NUM_PGM_TYPES:].reshape(-1, NUM_PGM_BLOCKS, NUM_PGM_STEPS, NUM_PGM_PARAMS)

save_obj = {
    'pgm': out_pgm,
    'param': out_param,
}
print('save_dir:%s'%save_dir)
torch.save(save_obj, os.path.join(save_dir, '%04d' % 0 + '.pt'))
save_ply(os.path.join(pts_save_dir, '%04d' % 0 + '_ori.ply'), pc[0])

# Stop here; the text-conditioned path below is not run by this script.
sys.exit()

# Text-conditioned generation path.
# NOTE(review): the script exits before reaching this section; it is kept for
# reference. `texts` is defined here so the section is self-contained.
texts = args.text.split('|') if exists(args.text) else []

for text in tqdm(texts):
    if args.gentxt:
        text_tokens, gen_texts = dalle.generate_texts(tokenizer, text=text, filter_thres = args.top_k)
        text = gen_texts[0]
    else:
        text_tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda()

    # Replicate the prompt so args.num_images samples are drawn for it.
    text_tokens = repeat(text_tokens, '() n -> b n', b = args.num_images)

    outputs = []

    for text_chunk in tqdm(text_tokens.split(args.batch_size), desc = f'generating images for - {text}'):
        output = dalle.generate_images(text_chunk, filter_thres = args.top_k)
        outputs.append(output)

    outputs = torch.cat(outputs)

    # save all images

    file_name = text
    # Directory name: prompt with spaces replaced, truncated to 100 characters.
    outputs_dir = Path(args.outputs_dir) / file_name.replace(' ', '_')[:(100)]
    outputs_dir.mkdir(parents = True, exist_ok = True)

    for i, image in tqdm(enumerate(outputs), desc = 'saving images'):
        save_ply(os.path.join(outputs_dir,'%04d.ply'%i),image)

    # The caption is identical for every sample, so write it once per prompt;
    # explicit UTF-8 keeps non-ASCII (e.g. Chinese) prompts portable.
    with open(outputs_dir / 'caption.txt', 'w', encoding='utf-8') as f:
        f.write(file_name)

    print(f'created {args.num_images} images at "{str(outputs_dir)}"')
